/*
 * Copyright (c) 2021, Peter Abeles. All Rights Reserved.
 *
 * This file is part of BoofCV (http://boofcv.org).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package boofcv.alg.filter.convolve.border;

import boofcv.generate.AutoTypeImage;
import boofcv.generate.CodeGeneratorBase;

import java.io.FileNotFoundException;

/**
 * Code generator which prints the source of a class containing static {@code horizontal},
 * {@code vertical} and {@code convolve} functions. The generated functions convolve only the
 * border pixels of a single band image, reading the input through an {@code ImageBorder_*}
 * so that samples outside the image are supplied by the border rule.
 *
 * <p>Two families are generated for each requested output type: a plain version and a version
 * taking an {@code int divisor} by which the accumulated sum is divided before it is stored.</p>
 *
 * <p>The output stream {@code out} and the generated class name {@code className} are not declared
 * here; they are presumably provided by {@link CodeGeneratorBase}.</p>
 *
 * @author Peter Abeles
 */
public class GenerateConvolveJustBorder_General_SB extends CodeGeneratorBase {

	/** Output image type of the functions currently being generated. */
	AutoTypeImage imageOut;
	/** Kernel type suffix, used in {@code Kernel1D_<typeKernel>} and {@code Kernel2D_<typeKernel>}. */
	String typeKernel;
	/** Name of the input border class, {@code ImageBorder_<typeKernel>}. */
	String typeInput;
	/** Element type of the kernel's {@code data} array. */
	String dataKernel;
	/** Primitive type of the variable in which the convolution sum is accumulated. */
	String sumType;

	/**
	 * Prints the complete generated class: preamble, one set of functions per supported output
	 * type and finally the closing brace.
	 */
	@Override
	public void generateCode() throws FileNotFoundException {

		printPreamble();

		// functions which store the raw sum
		createMethod(AutoTypeImage.I16,false);
		createMethod(AutoTypeImage.S32,false);
		createMethod(AutoTypeImage.F32,false);
		createMethod(AutoTypeImage.F64,false);

		// functions which divide the sum by a user supplied divisor
		createMethod(AutoTypeImage.I8,true);
		createMethod(AutoTypeImage.I16,true);
		createMethod(AutoTypeImage.S32,true);
		createMethod(AutoTypeImage.F32,true);

		// The bounded variants are intentionally not generated at the moment:
//		createBound(AutoTypeImage.F32);
//		createBound(AutoTypeImage.I8);

		out.println("}");
	}

	/**
	 * Prints the horizontal, vertical and 2D convolution functions for one output image type.
	 *
	 * @param imageOut type of the output image
	 * @param divide if true the generated functions take a divisor and divide the sum by it
	 */
	public void createMethod( AutoTypeImage imageOut , boolean divide ) {
		String typeCast = selectTypes(imageOut);

		addHorizontal(typeCast,divide);
		addVertical(typeCast,divide);
		addConvolution(typeCast,divide);
	}

	/**
	 * Prints the 2D convolution function which clamps the sum to a min/max range for one output
	 * image type.
	 *
	 * @param imageOut type of the output image
	 */
	public void createBound( AutoTypeImage imageOut ) {
		String typeCast = selectTypes(imageOut);

		addConvolutionBound(typeCast);
	}

	/**
	 * Updates all the type related fields for the specified output image.
	 *
	 * @param imageOut type of the output image
	 * @return the cast which is prepended to the sum when it is written into the output array
	 */
	private String selectTypes( AutoTypeImage imageOut ) {
		this.imageOut = imageOut;

		typeKernel = imageOut.getKernelType();
		dataKernel = imageOut.getKernelDataType();
		typeInput = "ImageBorder_"+typeKernel;

		sumType = imageOut.getSumType();
		return imageOut.getTypeCastFromSum();
	}

	/** Text appended to the generated parameter list: the divisor argument or nothing. */
	private static String divisorParameter( boolean divide ) {
		return divide ? ", int divisor" : "";
	}

	/**
	 * Generated declaration of {@code halfDivisor}. It is only needed for integer images, where it
	 * is added to the sum so that the integer division rounds instead of truncating.
	 */
	private String halfDivisorDeclaration( boolean divide ) {
		return divide && imageOut.isInteger() ? "\t\tfinal int halfDivisor = divisor/2;\n" : "";
	}

	/**
	 * Generated expression for the value which is stored, without the type cast.
	 *
	 * <p>Integer images use {@code (total+halfDivisor)/divisor}. Floating point images divide
	 * directly: their division is not truncated, so adding half the divisor would only bias the
	 * result by about one half.</p>
	 */
	private String resultExpression( boolean divide ) {
		if( !divide )
			return "total";
		return imageOut.isInteger() ? "((total+halfDivisor)/divisor)" : "(total/divisor)";
	}

	/** Prints the import statements, class level documentation and the opening of the class. */
	public void printPreamble() {
		out.print("import boofcv.struct.border.ImageBorder_F32;\n" +
				"import boofcv.struct.border.ImageBorder_S32;\n" +
				"import boofcv.struct.border.ImageBorder_F64;\n" +
				"import boofcv.struct.convolve.*;\n" +
				"import boofcv.struct.image.*;\n" +
				"\n" +
				"/**\n" +
				" * <p>\n" +
				" * Convolves just the image's border. How the border condition is handled is specified by the {@link boofcv.struct.border.ImageBorder}\n" +
				" * passed in. For 1D kernels only the horizontal or vertical borders are processed.\n" +
				" * </p>\n" +
				" * \n" +
				" * <p>\n" +
				" * WARNING: Do not modify. Automatically generated by "+getClass().getSimpleName()+".\n" +
				" * </p>\n" +
				" * \n" +
				" * @author Peter Abeles\n" +
				" */\n" +
				"public class "+className+" {\n\n");
	}

	/**
	 * Prints the function which applies a 1D kernel along rows. Only the left {@code offset}
	 * columns and the right {@code kernelWidth-offset-1} columns of each row are written.
	 *
	 * @param typeCast cast prepended to the stored value
	 * @param divide if true the function takes a divisor and divides the sum by it
	 */
	public void addHorizontal( String typeCast , boolean divide ) {
		String typeOutput = imageOut.getSingleBandName();
		String dataOutput = imageOut.getDataType();
		String result = typeCast+resultExpression(divide);

		out.print("\tpublic static void horizontal(Kernel1D_"+typeKernel+" kernel, "+typeInput+" input, "+typeOutput+" output "+divisorParameter(divide)+") {\n" +
				"\t\tfinal "+dataOutput+"[] dataDst = output.data;\n" +
				"\t\tfinal "+dataKernel+"[] dataKer = kernel.data;\n" +
				"\n" +
				"\t\tfinal int offset = kernel.getOffset();\n" +
				"\t\tfinal int kernelWidth = kernel.getWidth();\n" +
				"\t\tfinal int width = output.getWidth();\n" +
				"\t\tfinal int height = output.getHeight();\n" +
				"\t\tfinal int borderRight = kernelWidth-offset-1;\n" +
				halfDivisorDeclaration(divide) +
				"\n" +
				"\t\tfor (int y = 0; y < height; y++) {\n" +
				"\t\t\tint indexDest = output.startIndex + y * output.stride;\n" +
				"\n" +
				"\t\t\tfor ( int x = 0; x < offset; x++ ) {\n" +
				"\t\t\t\t"+sumType+" total = 0;\n" +
				"\t\t\t\tfor (int k = 0; k < kernelWidth; k++) {\n" +
				"\t\t\t\t\ttotal += input.get(x+k-offset,y) * dataKer[k];\n" +
				"\t\t\t\t}\n" +
				"\t\t\t\tdataDst[indexDest++] = "+result+";\n" +
				"\t\t\t}\n" +
				"\n" +
				"\t\t\tindexDest = output.startIndex + y * output.stride + width-borderRight;\n" +
				"\t\t\tfor ( int x = width-borderRight; x < width; x++ ) {\n" +
				"\t\t\t\t"+sumType+" total = 0;\n" +
				"\t\t\t\tfor (int k = 0; k < kernelWidth; k++) {\n" +
				"\t\t\t\t\ttotal += input.get(x+k-offset,y) * dataKer[k];\n" +
				"\t\t\t\t}\n" +
				"\t\t\t\tdataDst[indexDest++] = "+result+";\n" +
				"\t\t\t}\n" +
				"\t\t}\n" +
				"\t}\n\n");
	}

	/**
	 * Prints the function which applies a 1D kernel along columns. Only the top {@code offset}
	 * rows and the bottom {@code kernelWidth-offset-1} rows of each column are written.
	 *
	 * @param typeCast cast prepended to the stored value
	 * @param divide if true the function takes a divisor and divides the sum by it
	 */
	public void addVertical( String typeCast , boolean divide ) {
		String typeOutput = imageOut.getSingleBandName();
		String dataOutput = imageOut.getDataType();
		String result = typeCast+resultExpression(divide);

		out.print("\tpublic static void vertical(Kernel1D_"+typeKernel+" kernel, "+typeInput+" input, "+typeOutput+" output "+divisorParameter(divide)+") {\n" +
				"\t\tfinal "+dataOutput+"[] dataDst = output.data;\n" +
				"\t\tfinal "+dataKernel+"[] dataKer = kernel.data;\n" +
				"\n" +
				"\t\tfinal int offset = kernel.getOffset();\n" +
				"\t\tfinal int kernelWidth = kernel.getWidth();\n" +
				"\t\tfinal int width = output.getWidth();\n" +
				"\t\tfinal int height = output.getHeight();\n" +
				"\t\tfinal int borderBottom = kernelWidth-offset-1;\n" +
				halfDivisorDeclaration(divide) +
				"\n" +
				"\t\tfor ( int x = 0; x < width; x++ ) {\n" +
				"\t\t\tint indexDest = output.startIndex + x;\n" +
				"\n" +
				"\t\t\tfor (int y = 0; y < offset; y++, indexDest += output.stride) {\n" +
				"\t\t\t\t"+sumType+" total = 0;\n" +
				"\t\t\t\tfor (int k = 0; k < kernelWidth; k++) {\n" +
				"\t\t\t\t\ttotal += input.get(x,y+k-offset) * dataKer[k];\n" +
				"\t\t\t\t}\n" +
				"\t\t\t\tdataDst[indexDest] = "+result+";\n" +
				"\t\t\t}\n" +
				"\n" +
				"\t\t\tindexDest = output.startIndex + (height-borderBottom) * output.stride + x;\n" +
				"\t\t\tfor (int y = height-borderBottom; y < height; y++, indexDest += output.stride) {\n" +
				"\t\t\t\t"+sumType+" total = 0;\n" +
				"\t\t\t\tfor (int k = 0; k < kernelWidth; k++ ) {\n" +
				"\t\t\t\t\ttotal += input.get(x,y+k-offset) * dataKer[k];\n" +
				"\t\t\t\t}\n" +
				"\t\t\t\tdataDst[indexDest] = "+result+";\n" +
				"\t\t\t}\n" +
				"\t\t}\n" +
				"\t}\n\n");
	}

	/**
	 * Prints the function which applies a 2D kernel to the image border, optionally dividing the
	 * sum by a divisor.
	 *
	 * @param typeCast cast prepended to the stored value
	 * @param divide if true the function takes a divisor and divides the sum by it
	 */
	public void addConvolution( String typeCast , boolean divide ) {
		printConvolve2D(divisorParameter(divide), halfDivisorDeclaration(divide), "",
				typeCast+resultExpression(divide));
	}

	/**
	 * Prints the function which applies a 2D kernel to the image border and clamps the sum to the
	 * range {@code minValue} to {@code maxValue} before storing it.
	 *
	 * @param typeCast cast prepended to the stored value
	 */
	public void addConvolutionBound( String typeCast ) {
		String boundParameters = ",\n" +
				"\t\t\t\t\t\t\t\t"+sumType+" minValue , "+sumType+" maxValue ";
		String clamp = "\n" +
				"\t\t\t\tif( total < minValue )\n" +
				"\t\t\t\t\ttotal = minValue;\n" +
				"\t\t\t\telse if( total > maxValue )\n" +
				"\t\t\t\t\ttotal = maxValue;\n" +
				"\n";

		printConvolve2D(boundParameters, "", clamp, typeCast+"total");
	}

	/**
	 * Shared template for the 2D border convolution. The generated function first processes the
	 * left and right borders over the full image height, then the top and bottom borders for the
	 * columns in between, so that no pixel is processed twice.
	 *
	 * @param extraParameters text inserted after {@code output } in the parameter list
	 * @param extraSetup statements inserted after the local variable declarations, or empty
	 * @param afterSum statements inserted between computing the sum and storing it, or empty
	 * @param result expression, including any cast, which is written into the output array
	 */
	private void printConvolve2D( String extraParameters , String extraSetup ,
								  String afterSum , String result ) {
		String typeOutput = imageOut.getSingleBandName();
		String dataOutput = imageOut.getDataType();

		// kernel sum at pixel (x,y) followed by optional post-processing; identical for all four borders
		String pixel =
				"\t\t\t\t"+sumType+" total = 0;\n" +
				"\t\t\t\tint indexKer = 0;\n" +
				"\t\t\t\tfor( int i = -offsetL; i <= offsetR; i++ ) {\n" +
				"\t\t\t\t\tfor (int j = -offsetL; j <= offsetR; j++) {\n" +
				"\t\t\t\t\t\ttotal += input.get(x+j,y+i) * dataKer[indexKer++];\n" +
				"\t\t\t\t\t}\n" +
				"\t\t\t\t}\n" +
				afterSum;

		out.print("\tpublic static void convolve(Kernel2D_"+typeKernel+" kernel, "+typeInput+" input, "+typeOutput+" output "+extraParameters+") {\n" +
				"\t\tfinal "+dataOutput+"[] dataDst = output.data;\n" +
				"\t\tfinal "+dataKernel+"[] dataKer = kernel.data;\n" +
				"\n" +
				"\t\tfinal int offsetL = kernel.getOffset();\n" +
				"\t\tfinal int offsetR = kernel.getWidth()-offsetL-1;\n"+
				"\t\tfinal int width = output.getWidth();\n" +
				"\t\tfinal int height = output.getHeight();\n" +
				extraSetup +
				"\n" +
				"\t\t// convolve along the left and right borders\n" +
				"\t\tfor (int y = 0; y < height; y++) {\n" +
				"\t\t\tint indexDest = output.startIndex + y * output.stride;\n" +
				"\n" +
				"\t\t\tfor ( int x = 0; x < offsetL; x++ ) {\n" +
				pixel +
				"\t\t\t\tdataDst[indexDest++] = "+result+";\n" +
				"\t\t\t}\n" +
				"\n" +
				"\t\t\tindexDest = output.startIndex + y * output.stride + width-offsetR;\n" +
				"\t\t\tfor ( int x = width-offsetR; x < width; x++ ) {\n" +
				pixel +
				"\t\t\t\tdataDst[indexDest++] = "+result+";\n" +
				"\t\t\t}\n" +
				"\t\t}\n" +
				"\n" +
				"\t\t// convolve along the top and bottom borders\n" +
				"\t\tfor ( int x = offsetL; x < width-offsetR; x++ ) {\n" +
				"\t\t\tint indexDest = output.startIndex + x;\n" +
				"\n" +
				"\t\t\tfor (int y = 0; y < offsetL; y++, indexDest += output.stride) {\n" +
				pixel +
				"\t\t\t\tdataDst[indexDest] = "+result+";\n" +
				"\t\t\t}\n" +
				"\n" +
				"\t\t\tindexDest = output.startIndex + (height-offsetR) * output.stride + x;\n" +
				"\t\t\tfor (int y = height-offsetR; y < height; y++, indexDest += output.stride) {\n" +
				pixel +
				"\t\t\t\tdataDst[indexDest] = "+result+";\n" +
				"\t\t\t}\n" +
				"\t\t}\n" +
				"\t}\n\n");
	}

	/** Runs the generator. Command line arguments are ignored. */
	public static void main( String[] args ) throws FileNotFoundException {
		GenerateConvolveJustBorder_General_SB generator = new GenerateConvolveJustBorder_General_SB();
		generator.generateCode();
	}
}
